from copy import deepcopy
import os.path as osp

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class ReplayBuffer:
    """
    Bounded store of past samples, kept as one numpy array per field.

    Every field shares the same leading (sample) dimension, so the number
    of stored samples is the length of any single field array.
    """

    def __init__(self, size):
        """
        Create Replay buffer.

        Args:
            size (int): max number of samples to store in the buffer
                When the buffer overflows the old memories are dropped
        """
        self._buffer = []  # one numpy array per field
        self._maxsize = size
        self._types = []  # 'list' or 'tensor', one entry per field

    def __len__(self):
        """Return the number of stored samples (not the number of fields)."""
        if not self._buffer:
            return 0
        return len(self._buffer[0])

    @staticmethod
    def _to_numpy(field):
        """Convert one incoming field (list or torch.Tensor) to numpy."""
        if isinstance(field, list):
            return np.asarray(field)
        return field.detach().cpu().numpy()

    def add(self, samples):
        """
        Add samples to the memory.

        When the buffer would overflow, randomly chosen old samples are
        dropped so that at most ``size`` samples remain; the new samples
        are always appended after the surviving old ones.

        Args:
            samples (tuple): tuple of input argument
                each element is a list or a torch.Tensor, all sharing the
                same leading (sample) dimension

        Raises:
            ValueError: if the number of fields differs from what is
                already stored
        """
        incoming = [self._to_numpy(field) for field in samples]
        if not incoming:
            return

        if not self._buffer:
            # First insertion fixes the return type of each field.
            self._types = [
                'list' if isinstance(field, list) else 'tensor'
                for field in samples
            ]
            merged = incoming
        else:
            if len(incoming) != len(self._buffer):
                raise ValueError(
                    "expected %d fields, got %d"
                    % (len(self._buffer), len(incoming))
                )
            n_old = len(self)
            overflow = n_old + len(incoming[0]) - self._maxsize
            stored = self._buffer
            if overflow > 0:
                # Pick survivors without replacement so exactly `overflow`
                # old samples are evicted (duplicates would evict fewer).
                n_keep = max(n_old - overflow, 0)
                keep_inds = np.sort(np.random.permutation(n_old)[:n_keep])
                stored = [field[keep_inds] for field in stored]
            merged = [
                np.concatenate((old, new))
                for old, new in zip(stored, incoming)
            ]

        # A single addition larger than the capacity: keep the newest.
        excess = len(merged[0]) - self._maxsize
        if excess > 0:
            merged = [field[excess:] for field in merged]
        self._buffer = merged

    def sample(self, batch_size):
        """
        Sample a batch of experiences.

        Indices are drawn uniformly over the stored samples, with
        replacement only when fewer than ``batch_size`` are stored.

        Args:
            batch_size (int): number of samples to draw

        Returns:
            list: one entry per field, a Python list for fields added as
                lists and a torch.Tensor otherwise
        """
        n_stored = len(self)
        keep_inds = np.random.choice(
            n_stored, batch_size,
            replace=n_stored < batch_size
        )
        return [
            field[keep_inds].tolist() if kind == 'list'
            else torch.from_numpy(field[keep_inds])
            for field, kind in zip(self._buffer, self._types)
        ]


class EBMTrainer:
    """Train/test models on manipulation."""

    def __init__(self, model, data_loaders, args):
        """
        Set up the model, optimizer, tensorboard writer and replay buffer.

        Args:
            model (torch.nn.Module): energy model to optimize
            data_loaders (dict): loaders indexed by 'train'/'val'/'test'
            args: options namespace; the attributes read by this class are
                use_ema, tensorboard_dir, lr, use_buffer, buffer_size,
                ckpnt, eval, epochs, batch_size, kl, kl_coeff, num_steps
                and step_lr
        """
        self.model = model
        self.data_loaders = data_loaders
        self.args = args
        if args.use_ema:
            self.model_ema = deepcopy(model)

        self.writer = SummaryWriter(f'runs/{args.tensorboard_dir}')
        self.optimizer = Adam(
            model.parameters(), lr=args.lr, betas=(0.0, 0.9), eps=1e-8
        )
        # An empty list keeps `len(self.buffer)` valid when no buffer is used
        self.buffer = []
        if args.use_buffer:
            self.buffer = ReplayBuffer(args.buffer_size)

    def run(self):
        """
        Run evaluation only (args.eval) or the full train/val schedule.

        Returns:
            the model held by this trainer
        """
        # Set
        start_epoch = 0
        val_acc_prev_best = -1.0

        # Load
        if osp.exists(self.args.ckpnt):
            start_epoch, val_acc_prev_best = self._load_ckpnt()
            # NOTE(review): the stored best score is discarded on purpose
            # here, presumably because no validation metric is computed
            # below -- confirm before removing.
            val_acc_prev_best = -1.0

        # Eval?
        if self.args.eval:
            self.model.eval()
            self.train_test_loop('test')
            return self.model

        # Go!
        for epoch in range(start_epoch, self.args.epochs):
            print(f"Epoch: {epoch + 1}/{self.args.epochs}")
            self.model.train()
            # Train
            self.train_test_loop('train', epoch)
            # Validate
            print("\nValidation")
            with torch.no_grad():
                self.train_test_loop('val', epoch)
            # No validation metric is computed, so with the constant below
            # the "Store" branch is taken on every epoch.
            val_acc = 0

            # Store
            if val_acc >= val_acc_prev_best:
                print("Saving Checkpoint")
                torch.save({
                    "epoch": epoch + 1,
                    "model_state_dict": self.model.state_dict(),
                    "model_ema_state_dict": (
                        self.model_ema.state_dict()
                        if self.args.use_ema else None
                    ),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "best_acc": val_acc
                }, self.args.ckpnt)
                val_acc_prev_best = val_acc
            else:
                print("Updating Checkpoint")
                checkpoint = torch.load(self.args.ckpnt)
                checkpoint["epoch"] += 1
                torch.save(checkpoint, self.args.ckpnt)

        return self.model

    def _load_ckpnt(self):
        """
        Restore model, EMA model (if used and stored) and optimizer.

        Returns:
            tuple: (epoch to resume from, best score stored in the file)
        """
        ckpnt = torch.load(self.args.ckpnt)
        self.model.load_state_dict(ckpnt["model_state_dict"], strict=False)
        # run() saves the EMA weights, so restore them as well; otherwise
        # the EMA copy would restart from the weights it had at __init__.
        ema_state = ckpnt.get("model_ema_state_dict")
        if self.args.use_ema and ema_state is not None:
            self.model_ema.load_state_dict(ema_state, strict=False)
        self.optimizer.load_state_dict(ckpnt["optimizer_state_dict"])
        start_epoch = ckpnt["epoch"]
        val_acc_prev_best = ckpnt['best_acc']
        return start_epoch, val_acc_prev_best

    def _prepare_inputs(self, batch):
        """
        Build positive and negative inputs from a loader batch.

        The negatives start as an independent copy of the batch, so the
        in-place replay substitution cannot overwrite the positives.
        """
        return {'pos': batch, 'neg': deepcopy(batch)}

    def _sample_from_buffer(self, inputs):
        """
        Replace most negatives with samples drawn from the replay buffer.

        Each position is replaced when its uniform draw exceeds 0.001;
        the remaining positions keep their current negative.

        Args:
            inputs (dict): 'pos' and 'neg' entries, each indexable by field

        Returns:
            dict: the same dict, with 'neg' fields updated in place
        """
        replay_batch = self.buffer.sample(len(inputs['pos'][0]))
        replay_mask = np.random.uniform(0, 1, len(replay_batch[0])) > 0.001
        for k in range(len(replay_batch)):  # loop over fields
            if isinstance(inputs['neg'][k], list):
                field = np.asarray(inputs['neg'][k])
                # Mask both sides so the shapes match when some positions
                # are not replaced.
                field[replay_mask] = np.asarray(replay_batch[k])[replay_mask]
                inputs['neg'][k] = field.tolist()
            else:
                device = inputs['neg'][k].device
                mask = torch.from_numpy(replay_mask)
                inputs['neg'][k][mask.to(device)] = (
                    replay_batch[k][mask].to(device)
                )
        return inputs

    def train_test_loop(self, mode='train', epoch=1000):
        """
        Run one pass over `data_loaders[mode]`.

        For every batch: build negatives (optionally seeded from the
        replay buffer), refine them with Langevin dynamics, compute the
        loss, update the model when mode is 'train', and log energies.

        Args:
            mode (str): key into `data_loaders`; updates only for 'train'
            epoch (int): epoch index, used for the logging step
        """
        loader = self.data_loaders[mode]
        for step, ex in enumerate(tqdm(loader)):
            inputs = self._prepare_inputs(ex)

            # Load from buffer
            if len(self.buffer) > self.args.batch_size:
                inputs = self._sample_from_buffer(inputs)

            # Run Langevin dynamics
            neg, neg_kl, seq = self.langevin(inputs['neg'])

            # Save to buffer
            if self.args.use_buffer:
                self.buffer.add(neg)

            # Compute energies
            energy_pos = self.model(inputs['pos'])
            energy_neg = self.model(neg)

            # Losses
            loss = energy_pos.mean() - energy_neg.mean()
            loss = loss + ((energy_pos ** 2).mean() + (energy_neg ** 2).mean())
            if self.args.kl:
                # Parameters are frozen for this term, so its gradient can
                # only flow through the graph kept in neg_kl.
                self.model.requires_grad_(False)
                loss_kl = self.model(neg_kl).mean()
                self.model.requires_grad_(True)
            else:
                loss_kl = 0
            loss = loss + self.args.kl_coeff * loss_kl

            # Update
            if mode == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(self.model.parameters(), 0.5)
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.args.use_ema:
                    # Source is the live model, target is the EMA copy
                    ema_model(self.model, self.model_ema, mu=0.999)

            # Logging
            global_step = epoch * len(loader) + step
            pos_avg = energy_pos.mean().item()
            neg_avg = energy_neg.mean().item()
            self.writer.add_scalar(
                'Positive_energy_avg/' + mode, pos_avg, global_step
            )
            self.writer.add_scalar(
                'Negative_energy_avg/' + mode, neg_avg, global_step
            )
            self.writer.add_scalar(
                'Energy_diff/' + mode, abs(pos_avg - neg_avg), global_step
            )

            # Visualizations
            self._visualize(seq)

    def langevin(self, im_neg):
        """
        Langevin dynamics implemented for images, re-implement.

        Runs `args.num_steps` noisy gradient steps on the energy. Gradient
        tracking is forced on inside, so this also works when the caller
        is under `torch.no_grad()` (as the validation loop is).

        Args:
            im_neg (torch.Tensor): initial negative samples

        Returns:
            tuple: (final detached samples clamped to [0, 1],
                last-step sample that keeps its graph for the KL term,
                list with the detached samples after every step)
        """
        im_noise = torch.randn_like(im_neg).detach()
        im_negs_samples = []
        # Defined up front so the return is valid even with zero steps
        im_neg_kl = im_neg
        last_step = self.args.num_steps - 1

        with torch.enable_grad():
            for i in range(self.args.num_steps):
                # Add noise
                im_noise.normal_()
                im_neg = im_neg + 0.005 * im_noise

                # Forward pass
                im_neg.requires_grad_(requires_grad=True)
                energy = self.model(im_neg)

                # Backward pass (gradients wrt image)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg])[0]
                im_neg_before = im_neg
                im_neg = im_neg - self.args.step_lr * im_grad

                # Compute kl image for last step
                if i == last_step:
                    im_neg_kl = im_neg_before.clone()
                    energy = self.model(im_neg_kl)
                    im_grad = torch.autograd.grad(
                        [energy.sum()], [im_neg_kl],
                        create_graph=True
                    )[0]
                    im_neg_kl = im_neg_kl - self.args.step_lr * im_grad
                    im_neg_kl = torch.clamp(im_neg_kl, 0, 1)

                # Detach/clamp/store
                im_neg = torch.clamp(im_neg.detach(), 0, 1)
                im_negs_samples.append(im_neg)

        return im_neg, im_neg_kl, im_negs_samples

    def _visualize(self, seq):
        """Hook for visualizing the Langevin sequence; no-op here."""
        pass


def ema_model(model, model_ema, mu=0.99):
    """
    Blend the parameters of `model` into `model_ema`.

    Each parameter of `model_ema` becomes
    ``mu * ema_value + (1 - mu) * model_value``; `model` is left untouched.
    Parameters are paired in the order `parameters()` yields them.

    Args:
        model: module providing the current parameters
        model_ema: module whose parameters are overwritten
        mu (float): weight kept from the previous averaged value
    """
    paired = zip(model.parameters(), model_ema.parameters())
    for current, averaged in paired:
        blended = mu * averaged.data + (1 - mu) * current.data
        averaged.data = blended


def clip_grad_norm_(parameters, max_norm, norm_type=2.0):
    """
    Clip gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.
    Parameters without a gradient are ignored.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int or str): type of the used p-norm.
            Can be ``'inf'`` for infinity norm.

    Returns:
        Total norm of the gradients (viewed as a single vector), measured
        before clipping; a zero tensor when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    # float('inf') also covers the string form 'inf'
    norm_type = float(norm_type)
    if not parameters:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    # Compare against the float value: after the cast above, norm_type can
    # never equal the string 'inf'.
    if norm_type == float('inf'):
        total_norm = torch.stack([
            p.grad.detach().abs().max().to(device) for p in parameters
        ]).max()
    else:
        total_norm = torch.norm(torch.stack([
            torch.norm(p.grad.detach(), norm_type).to(device)
            for p in parameters
        ]), norm_type)
    # The small epsilon guards against division by a zero norm
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return total_norm
